#!D:/Code/python
# -*- coding: utf-8 -*-
# @Time : 2021/5/7 0007 20:20
# @Author : xgf
# @File : method_DL_template.py
# @Software : PyCharm
import numpy

import torch
from torch import nn, Tensor, optim
from torch.autograd import Variable
import torch.nn.functional as F
from typing import (
    TypeVar, Type, Union, Optional, Any,
    List, Dict, Tuple, Callable, NamedTuple
)

import random
import time
import os
import copy
import re
import logging
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
import itertools

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from utils import Args, D, timeit

from log import Log
# Identifier for this experiment; also passed to the logger below.
method_name = 'encrypted-traffic-analysis-2021'
# Project-local logger (log.Log), created at import time. The first argument
# is presumably the log directory (relative to the CWD) — confirm in log.py.
# The __main__ loop records per-epoch metrics via mylog.state_dict_update(...).
mylog = Log('../encrypted-traffic-analysis-2021_log', method_name)

def get_args() -> Args:
    """
    Build the argument set for this experiment.

    The default data directories are resolved relative to this file's
    directory, so they do not depend on the current working directory.

    @return: an ``Args`` object built from the ``D`` descriptors below
             (``Args`` and ``D`` come from the project-local ``utils`` module).
    """
    script_dir = os.path.dirname(__file__)
    default_raw_data_dir = os.path.join(script_dir, "dataset/traffic")
    default_feature_data_dir = os.path.join(script_dir, "dataset/traindata")
    return Args([
        D("batchSize", int, 32),
        D("learningRate", float, 1e-3),
        D("numEpochs", int, 1000),
        D("rawDataDir", str, default_raw_data_dir),
        D("dataDir", str, default_feature_data_dir),
        D("saveDir", str, None),
        D("nClass", int, 50),
        # Fraction of samples used for training. It was declared as ``int``,
        # which contradicts the 0.8 default. Assuming D(name, type, default)
        # uses the type to parse/convert values, ``int`` would reject "0.8"
        # or truncate it to 0 — confirm against utils.D.
        D('splitData', float, 0.8),
    ])

def readNpy(file_path, allow_pickle=True):
    """
    Load the contents of a ``.npy`` file.

    @param file_path: path to the ``.npy`` file.
    @param allow_pickle: forwarded to ``np.load``. It defaults to True
        (the previous hard-coded behaviour), presumably because the feature
        files used by this script are object arrays, which numpy can only
        store via pickle — confirm against the file producer.
    @return: the array stored in the file. Its layout is defined by whatever
        wrote the file.

    SECURITY: with ``allow_pickle=True``, loading an object array unpickles
    it, and unpickling can execute arbitrary code. Only load trusted files,
    and pass ``allow_pickle=False`` for purely numeric arrays.
    """
    return np.load(file_path, allow_pickle=allow_pickle)

class AverageMeter(object):
    """Keep the latest value plus a running, count-weighted mean.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic (latest value, mean, total and count)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        # Mean over every observation recorded since the last reset.
        self.avg = self.sum / self.count

class EncryptTrafficDataset(Dataset):
    """
    Feature dataset with one (feature, label) pair per row.

    Column 0 of ``traffic_data`` holds each sample's feature and column 1
    holds its label. Any further columns are ignored.
    """
    def __init__(self, traffic_data, transform=None, target_transform=None):
        """
        @param traffic_data: array-like with at least two dimensions and at
            least two columns, e.g. an object array returned by ``np.load``.
        @param transform: optional callable applied to each feature before it
            is converted to a tensor.
        @param target_transform: optional callable applied to each label.
        @raise ValueError: if ``traffic_data`` does not have the expected shape.
        """
        traffic_data = np.array(traffic_data)
        # The old ``except IOError`` could never catch a bad shape (numpy
        # raises IndexError), and if it had fired the attributes would have
        # been left undefined. Validate explicitly and fail loudly instead.
        if traffic_data.ndim < 2 or traffic_data.shape[1] < 2:
            raise ValueError(
                "EncryptTrafficDataset expects an array with at least two "
                f"columns (feature, label); got shape {traffic_data.shape}"
            )
        self.traffic_features = traffic_data[:, 0]
        self.traffic_labels = traffic_data[:, 1]

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.traffic_labels)

    def __getitem__(self, idx):
        """Return ``(feature, label)``, with the feature as a float32 tensor."""
        feature = self.traffic_features[idx]
        label = self.traffic_labels[idx]
        if self.transform is not None:
            feature = self.transform(feature)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if isinstance(feature, torch.Tensor):
            feature = feature.to(torch.float32)
        else:
            # Go through numpy so lists and object-dtype rows become a float32
            # tensor. ``torch.Tensor(n)`` on a bare int would instead allocate
            # an uninitialised tensor of length n.
            feature = torch.from_numpy(np.array(feature, dtype=np.float32))
        return feature, label

class ThreeLinearNetwork(nn.Module):
    """
    A simple three-layer fully connected network, used for testing.

    It maps ``in_features`` inputs through two ReLU hidden layers of width
    ``hidden`` to ``n_class`` raw (unnormalised) logits. The defaults
    reproduce the original 30 -> 512 -> 512 -> 50 architecture, and the
    parameter names (``linear_relu_stack.{0,2,4}``) are unchanged, so
    existing checkpoints still load.
    """
    def __init__(self, in_features=30, n_class=50, hidden=512):
        super(ThreeLinearNetwork, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_class),
            # No activation here: nn.CrossEntropyLoss expects raw logits.
            # The former trailing ReLU clamped every negative logit to 0 and
            # blocked the gradient through it.
        )

    def forward(self, x):
        """
        @param x: float tensor whose last dimension is ``in_features``.
        @return: logits whose last dimension is ``n_class``.
        """
        return self.linear_relu_stack(x)

def train(dataloader, model, loss_fn, optimizer, device=None):
    """
    Run one training epoch over ``dataloader``.

    @param dataloader: yields ``(X, y)`` batches, where ``y`` holds class
        indices. They are cast to ``long`` here.
    @param model: the network to optimise. It is switched to training mode.
    @param loss_fn: loss called as ``loss_fn(pred, y)`` and returning the
        batch mean.
    @param optimizer: optimiser over ``model``'s parameters.
    @param device: device each batch is moved to. It defaults to the device
        of the model's first parameter (CPU if the model has none), so this
        function does not depend on a module-level ``device`` global.
    @return: the sample-weighted mean training loss for the epoch, or 0.0 if
        the dataloader yielded no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    # Callers may have left the model in eval mode after evaluating it.
    model.train()
    total_loss, seen = 0.0, 0
    for X, y in dataloader:
        # Tensor.long() is not in-place, so its result must be reassigned for
        # CrossEntropyLoss to receive integer class indices.
        X, y = X.to(device), y.long().to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device to avoid a host sync on every batch.
        n = len(X)
        total_loss = total_loss + loss.detach() * n
        seen += n
    return float(total_loss) / seen if seen else 0.0

def test(dataloader, model, loss_fn=None, device=None):
    """
    Evaluate ``model`` on ``dataloader`` and print its accuracy and average loss.

    @param dataloader: yields ``(X, y)`` batches, where ``y`` holds class
        indices. They are cast to ``long`` here.
    @param model: the network to evaluate. It is put in eval mode for the
        pass and then restored to its previous train/eval mode.
    @param loss_fn: loss called as ``loss_fn(pred, y)`` and returning the
        batch mean. Defaults to ``nn.CrossEntropyLoss()``, so this function
        no longer depends on a module-level ``loss_fn`` global.
    @param device: device each batch is moved to. Defaults to the device of
        the model's first parameter (CPU if the model has none).
    @return: ``(avg_loss, accuracy)``, both averaged per evaluated sample,
        or ``(0.0, 0.0)`` if there were no samples.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    was_training = model.training
    model.eval()
    test_loss, correct, seen = 0.0, 0.0, 0
    try:
        with torch.no_grad():
            for X, y in dataloader:
                # Tensor.long() is not in-place; reassign the result.
                X, y = X.to(device), y.long().to(device)
                pred = model(X)
                n = len(X)
                # loss_fn returns a batch mean, so weight it by the batch size.
                # Dividing by the sample count below then gives a per-sample
                # average. Summing the batch means and dividing by the sample
                # count would under-report the loss by ~batch_size.
                test_loss += loss_fn(pred, y).item() * n
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                seen += n
    finally:
        model.train(was_training)

    if seen:
        test_loss /= seen
        correct /= seen
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct


# ``test`` matches pytest's default test-function prefix. Stop pytest from
# collecting this helper (and erroring on missing fixtures) when a test
# module imports it.
test.__test__ = False

if __name__ == '__main__':
    # Kept at module scope rather than inside a main() function, so that
    # helpers reading ``device`` / ``loss_fn`` as module globals keep working.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    args = get_args()

    batch_size = args.batchSize  # batch size
    learning_rate = args.learningRate  # learning rate
    num_epochs = args.numEpochs  # number of passes over the training set
    data_dir = args.dataDir
    save_dir = args.saveDir
    n_class = args.nClass
    split_data = args.splitData

    # Load the pre-split feature files (presumably a 90% / 10% split, going
    # by the file names).
    # NOTE(review): these paths are hard-coded and ignore ``data_dir`` and
    # ``split_data``; confirm which source is intended.
    traindata = readNpy('./feature_extraction/undefence_90.npy')
    testdata = readNpy('./feature_extraction/undefence_10.npy')
    train_data = EncryptTrafficDataset(traindata)
    test_data = EncryptTrafficDataset(testdata)

    # Dataloaders. Evaluation results do not depend on sample order, so the
    # test loader is not shuffled.
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    model = ThreeLinearNetwork().to(device)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for t in range(num_epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        # Re-evaluate on the training set to get train loss/accuracy.
        train_loss, train_acc = test(train_dataloader, model)
        test_loss, test_acc = test(test_dataloader, model)
        # Previously the test accuracy was logged under 'train_acc_list'.
        mylog.state_dict_update([('train_loss_list', train_loss),
                                 ('train_acc_list', train_acc),
                                 ('valid_loss_list', test_loss),
                                 ('valid_acc_list', test_acc),
                                 ])
    print("Done!")